{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adadelta\n",
    ":label:`sec_adadelta`\n",
    "\n",
    "\n",
    "Adadelta是AdaGrad的另一种变体（ :numref:`sec_adagrad`），\n",
    "主要区别在于前者减少了学习率适应坐标的数量。\n",
    "此外，广义上Adadelta被称为没有学习率，因为它使用变化量本身作为未来变化的校准。\n",
    "Adadelta算法是在 :cite:`Zeiler.2012`中提出的。\n",
    "\n",
    "## Adadelta算法\n",
    "\n",
    "简而言之，Adadelta使用两个状态变量，$\\mathbf{s}_t$用于存储梯度二阶导数的漏平均值，$\\Delta\\mathbf{x}_t$用于存储模型本身中参数变化二阶导数的泄露平均值。请注意，为了与其他出版物和实现的兼容性，我们使用作者的原始符号和命名（没有其他真正理由为什么应该使用不同的希腊变量来表示在动量中用于相同用途的参数，即AdaGrad、RMSProp和Adadelta）。\n",
    "\n",
    "以下是Adadelta的技术细节。鉴于参数du jour是$\\rho$，我们获得了与 :numref:`sec_rmsprop`类似的以下泄漏更新：\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{s}_t & = \\rho \\mathbf{s}_{t-1} + (1 - \\rho) \\mathbf{g}_t^2.\n",
    "\\end{aligned}$$\n",
    "\n",
    "与 :numref:`sec_rmsprop`的区别在于，我们使用重新缩放的梯度$\\mathbf{g}_t'$执行更新，即\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{x}_t  & = \\mathbf{x}_{t-1} - \\mathbf{g}_t'. \\\\\n",
    "\\end{aligned}$$\n",
    "\n",
    "那么，调整后的梯度$\\mathbf{g}_t'$是什么？我们可以按如下方式计算它：\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{g}_t' & = \\frac{\\sqrt{\\Delta\\mathbf{x}_{t-1} + \\epsilon}}{\\sqrt{{\\mathbf{s}_t + \\epsilon}}} \\odot \\mathbf{g}_t, \\\\\n",
    "\\end{aligned}$$\n",
    "\n",
    "其中$\\Delta \\mathbf{x}_{t-1}$是重新缩放梯度的平方$\\mathbf{g}_t'$的泄漏平均值。我们将$\\Delta \\mathbf{x}_{0}$初始化为$0$，然后在每个步骤中使用$\\mathbf{g}_t'$更新它，即\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\Delta \\mathbf{x}_t & = \\rho \\Delta\\mathbf{x}_{t-1} + (1 - \\rho) {\\mathbf{g}_t'}^2,\n",
    "\\end{aligned}$$\n",
    "\n",
    "和$\\epsilon$（例如$10^{-5}$这样的小值）是为了保持数字稳定性而加入的。\n",
    "\n",
    "## 代码实现\n",
    "\n",
    "Adadelta需要为每个变量维护两个状态变量，即$\\mathbf{s}_t$和$\\Delta\\mathbf{x}_t$。这将产生以下实施。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList initAdadeltaStates(int featureDimension) {\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "    NDArray sW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray sB = manager.zeros(new Shape(1));\n",
    "    NDArray deltaW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray deltaB = manager.zeros(new Shape(1));\n",
    "    return new NDList(sW, deltaW, sB, deltaB);\n",
    "}\n",
    "\n",
    "public class Optimization {\n",
    "    public static void adadelta(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float rho = hyperparams.get(\"rho\");\n",
    "        float eps = (float) 1e-5;\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray state = states.get(2 * i);\n",
    "            NDArray delta = states.get(2 * i + 1);\n",
    "            // Update parameter, state, and delta\n",
    "            // In-place updates with the '__'i methods (ex. muli)\n",
    "            // state = rho * state + (1 - rho) * param.gradient^2\n",
    "            state.muli(rho).addi(param.getGradient().square().mul(1 - rho));\n",
    "            // rescaledGradient = ((delta + eps)^(1/2) / (state + eps)^(1/2)) * param.gradient\n",
    "            NDArray rescaledGradient = delta.add(eps).sqrt()\n",
    "                .div(state.add(eps).sqrt()).mul(param.getGradient());\n",
    "            // param -= rescaledGradient\n",
    "            param.subi(rescaledGradient);\n",
    "            // delta = rho * delta + (1 - rho) * g^2\n",
    "            delta.muli(rho).addi(rescaledGradient.square().mul(1 - rho));\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于每次参数更新，选择$\\rho = 0.9$相当于10个半衰期。由此我们得到：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainAdadelta(float rho, int numEpochs) throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"rho\", rho);\n",
    "    return TrainingChapter11.trainCh11(Optimization::adadelta, \n",
    "                                       initAdadeltaStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "TrainingChapter11.LossTime lossTime = trainAdadelta(0.9f, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了简洁实现，我们只需使用`Trainer`类中的`adadelta`算法。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Optimizer adadelta = Optimizer.adadelta().optRho(0.9f).build();\n",
    "\n",
    "TrainingChapter11.trainConciseCh11(adadelta, airfoil, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* Adadelta没有学习率参数。相反，它使用参数本身的变化率来调整学习率。\n",
    "* Adadelta需要两个状态变量来存储梯度的二阶导数和参数的变化。\n",
    "* Adadelta使用泄漏的平均值来保持对适当统计数据的运行估计。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 调整$\\rho$的值，会发生什么？\n",
    "1. 展示如何在不使用$\\mathbf{g}_t'$的情况下实现算法。为什么这是个好主意？\n",
    "1. Adadelta真的是学习率为0吗？你能找到Adadelta无法解决的优化问题吗？\n",
    "1. 将Adadelta的收敛行为与AdaGrad和RMSProp进行比较。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
